{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "# ^^^ pyforest auto-imports - don't write above this line\n",
    "import sys\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "from tqdm.auto import tqdm\n",
    "from copy import deepcopy\n",
    "from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import bib_lookup\n",
    "except ModuleNotFoundError:\n",
    "    sys.path.insert(0, \"/home/wenhao/Jupyter/wenhao/workspace/bib_lookup/\")\n",
    "try:\n",
    "    from torch_ecg.utils.misc import MovingAverage, list_sum\n",
    "except ModuleNotFoundError:\n",
    "    sys.path.insert(0, \"/home/wenhao/Jupyter/wenhao/workspace/torch_ecg/\")\n",
    "    from torch_ecg.utils.misc import MovingAverage, list_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cfg import TrainCfg, ModelCfg\n",
    "from trainer import CPSC2019Trainer, _MODEL_MAP\n",
    "from model import ECG_SEQ_LAB_NET_CPSC2019, ECG_UNET_CPSC2019, ECG_SUBTRACT_UNET_CPSC2019\n",
    "from dataset import CPSC2019"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TrainCfg.db_dir = \"/home/wenhao/Jupyter/wenhao/data/CPSC2019/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_train = CPSC2019(TrainCfg, training=True, lazy=False)\n",
    "ds_val = CPSC2019(TrainCfg, training=False, lazy=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## train CNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_config = deepcopy(TrainCfg)\n",
    "train_config.model_name = \"seq_lab_cnn\"\n",
    "\n",
    "model_config = deepcopy(ModelCfg[train_config.model_name])\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ECG_SEQ_LAB_NET_CPSC2019(\n",
    "    n_leads=train_config.n_leads,\n",
    "    config=model_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.module_size_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if torch.cuda.device_count() > 1:\n",
    "#     model = DP(model)\n",
    "\n",
    "model.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trainer = CPSC2019Trainer(\n",
    "    model=model,\n",
    "    model_config=model_config,\n",
    "    train_config=train_config,\n",
    "    device=device,\n",
    "    lazy=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer._setup_dataloaders(ds_train, ds_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bmd = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del bmd, trainer, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## train CRNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_config = deepcopy(TrainCfg)\n",
    "train_config.model_name = \"seq_lab_crnn\"\n",
    "\n",
    "model_config = deepcopy(ModelCfg[train_config.model_name])\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = ECG_SEQ_LAB_NET_CPSC2019(\n",
    "    n_leads=train_config.n_leads,\n",
    "    config=model_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.module_size_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if torch.cuda.device_count() > 1:\n",
    "#     model = DP(model)\n",
    "\n",
    "model.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = CPSC2019Trainer(\n",
    "    model=model,\n",
    "    model_config=model_config,\n",
    "    train_config=train_config,\n",
    "    device=device,\n",
    "    lazy=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer._setup_dataloaders(ds_train, ds_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bmd = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del bmd, trainer, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train U-Net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_config = deepcopy(TrainCfg)\n",
    "train_config.model_name = \"unet\"\n",
    "\n",
    "model_config = deepcopy(ModelCfg[train_config.model_name])\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# device = torch.device('cuda:7' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = _MODEL_MAP[train_config.model_name](\n",
    "    n_leads=train_config.n_leads,\n",
    "    config=model_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.module_size_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if torch.cuda.device_count() > 1:\n",
    "#     model = DP(model)\n",
    "\n",
    "model.to(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trainer = CPSC2019Trainer(\n",
    "    model=model,\n",
    "    model_config=model_config,\n",
    "    train_config=train_config,\n",
    "    device=device,\n",
    "    lazy=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer._setup_dataloaders(ds_train, ds_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "bmd = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "del bmd, trainer, model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# clear GPU"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# gather results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from matplotlib.pyplot import cm\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "sns.set()\n",
    "\n",
    "plt.rcParams['xtick.labelsize']=28\n",
    "plt.rcParams['ytick.labelsize']=28\n",
    "plt.rcParams['axes.labelsize']=40\n",
    "plt.rcParams['legend.fontsize']=24\n",
    "\n",
    "colors = plt.rcParams['axes.prop_cycle'].by_key()['color']\n",
    "\n",
    "markers = [\"p\", \"v\", \"s\", \"d\", \"x\", \"*\", \"+\", \"$\\heartsuit$\"]\n",
    "marker_size = 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch_ecg.utils.misc import MovingAverage, list_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ma = MovingAverage()\n",
    "\n",
    "ma = lambda x: x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_cnn = pd.read_csv(\"./results/TorchECG_12-22_16-52_ECG_SEQ_LAB_NET_CPSC2019_adamw_amsgrad_LR_0.001_BS_32_multi_scopic.csv\")\n",
    "df_crnn = pd.read_csv(\"./results/TorchECG_12-22_17-13_ECG_SEQ_LAB_NET_CPSC2019_adamw_amsgrad_LR_0.001_BS_32_multi_scopic.csv\")\n",
    "df_unet = pd.read_csv(\"./results/TorchECG_12-22_17-48_ECG_UNET_CPSC2019_adamw_amsgrad_LR_0.001_BS_32_none.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_cnn_train = df_cnn[df_cnn.part==\"train\"].dropna(subset=[\"qrs_score\"])\n",
    "df_crnn_train = df_crnn[df_crnn.part==\"train\"].dropna(subset=[\"qrs_score\"])\n",
    "df_unet_train = df_unet[df_unet.part==\"train\"].dropna(subset=[\"qrs_score\"])\n",
    "df_cnn_val = df_cnn[df_cnn.part==\"val\"].dropna(subset=[\"qrs_score\"])\n",
    "df_crnn_val = df_crnn[df_crnn.part==\"val\"].dropna(subset=[\"qrs_score\"])\n",
    "df_unet_val = df_unet[df_unet.part==\"val\"].dropna(subset=[\"qrs_score\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_lr = df_unet[df_unet.part==\"train\"][[\"step\", \"lr\"]].dropna(subset=[\"lr\"])\n",
    "df_lr[\"step\"] = df_lr[\"step\"]/(1600/32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(24,16))\n",
    "\n",
    "line_width = 4\n",
    "\n",
    "spacing = 2\n",
    "\n",
    "lines = []\n",
    "\n",
    "lines.append(ax.plot(\n",
    "    df_crnn_train.epoch.values[::spacing], ma(df_crnn_train.qrs_score.values)[::spacing],\n",
    "    marker=markers[0], markersize=marker_size, linewidth=line_width, color=colors[0], label=\"crnn-train\",\n",
    "))\n",
    "lines.append(ax.plot(\n",
    "    df_cnn_train.epoch.values[::spacing], ma(df_cnn_train.qrs_score.values)[::spacing],\n",
    "    marker=markers[1], markersize=marker_size, linewidth=line_width, color=colors[1], label=\"cnn-train\",\n",
    "))\n",
    "\n",
    "\n",
    "lines.append(ax.plot(\n",
    "    df_unet_train.epoch.values[::spacing], ma(df_unet_train.qrs_score.values)[::spacing],\n",
    "    marker=markers[2], markersize=marker_size, linewidth=line_width, color=colors[3], label=\"unet-train\",\n",
    "))\n",
    "lines.append(ax.plot(\n",
    "    df_crnn_train.epoch.values[::spacing], ma(df_crnn_val.qrs_score.values)[::spacing],\n",
    "    marker=markers[0], markersize=marker_size, linewidth=line_width, color=colors[0], ls=\"--\", label=\"crnn-val\",\n",
    "))\n",
    "lines.append(ax.plot(\n",
    "    df_cnn_train.epoch.values[::spacing], ma(df_cnn_val.qrs_score.values)[::spacing],\n",
    "    marker=markers[1], markersize=marker_size, linewidth=line_width, color=colors[1], ls=\"--\", label=\"cnn-val\",\n",
    "))\n",
    "lines.append(ax.plot(\n",
    "    df_unet_train.epoch.values[::spacing], ma(df_unet_val.qrs_score.values)[::spacing],\n",
    "    marker=markers[2], markersize=marker_size, linewidth=line_width, color=colors[3], ls=\"--\", label=\"unet-val\",\n",
    "))\n",
    "ax.set_ylim(0.55,1.05)\n",
    "ax.set_xlabel(\"Epochs (n.u.)\", fontsize=36)\n",
    "ax.set_ylabel(\"QRS score (n.u.)\", fontsize=36)\n",
    "\n",
    "ax2 = ax.twinx()\n",
    "lines.append(ax2.plot(\n",
    "    df_lr.step.values, ma(df_lr.lr.values),\n",
    "    color=colors[2], ls=\":\", linewidth=8, label=\"learning rate\",\n",
    "))\n",
    "ax2.set_ylabel(r\"$10^3\\times$\"+\"Learning Rate (n.u.)\")\n",
    "ax2.set_ylim(-0.00025,0.00225)\n",
    "ax2.set_yticks(np.arange(0,0.0025,0.0005).tolist())\n",
    "ax2.set_yticklabels([f\"{n:.1f}\" for n in np.arange(0,0.0025,0.0005)*1000])\n",
    "\n",
    "# ax.text(40,1.01, \"max lr = 0.02\", fontsize=30)\n",
    "# ax.text(5,0.61, f\"start lr = {df_lr.lr.values[0]:.5f}\", fontsize=30)\n",
    "\n",
    "lns = list_sum(lines[:-1])\n",
    "labs = [l.get_label() for l in lns]\n",
    "ax.legend(lns, labs, loc=\"lower right\", ncol=2, fontsize=30, bbox_to_anchor=(0.95, 0.1))\n",
    "\n",
    "ax2.legend(loc=\"upper left\", fontsize=30)\n",
    "ax2.set_xlim(-5,105);\n",
    "\n",
    "plt.savefig(\"./results/cpsc2019_nn_compare.svg\", dpi=1200, bbox_inches=\"tight\", transparent=False);\n",
    "plt.savefig(\"./results/cpsc2019_nn_compare.pdf\", dpi=1200, bbox_inches=\"tight\", transparent=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
